import torch
import cv2
import numpy as np
from CustomNet import CustomNet

# Load the model.
# The checkpoint stores the whole pickled module, not just a state_dict, so
# weights_only=False is needed (and CustomNet must be importable above).
# Unpickling can execute arbitrary code: only load checkpoints you trust.
device = torch.device("cpu")
# map_location remaps tensors saved on another device (e.g. CUDA) onto
# `device`, so a checkpoint trained on a GPU still loads on a CPU-only machine.
model = torch.load("./numbers.pth", map_location=device, weights_only=False)
model = model.to(device)
# Put the model in evaluation mode so layers such as dropout / batch-norm
# (if the model has any) use inference behaviour instead of training behaviour.
model.eval()

# -------------- Test the model on a sample image --------------
image_path = "./test/8.png"
image = cv2.imread(image_path)
if image is None:
    # cv2.imread reports failure by returning None instead of raising.
    raise FileNotFoundError(f"Could not read image: {image_path}")
# Add a leading batch axis: (H, W, C) -> (1, H, W, C).
image = np.expand_dims(image, 0)
image = torch.from_numpy(image)
# Reorder to channel-first layout: (N, H, W, C) -> (N, C, H, W).
image = torch.permute(image, [0, 3, 1, 2])
# NOTE(review): pixels are fed as raw 0-255 values in OpenCV's default BGR
# channel order; confirm this matches the preprocessing used in training.
with torch.no_grad():  # inference only: no autograd graph needed
    predict = model(image.float().to(device))
# Index of the highest score along the last axis (presumably the class axis).
result = torch.argmax(predict, dim=-1)
print(result)
